{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 前端模型转换"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import set_env"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "定义简单模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "class Conv2D1(torch.nn.Module):\n",
    "    \"\"\"Minimal model: a single 7x7 convolution from 3 to 6 channels, with bias.\"\"\"\n",
    "\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.conv = torch.nn.Conv2d(\n",
    "            in_channels=3, out_channels=6, kernel_size=7, bias=True\n",
    "        )\n",
    "\n",
    "    def forward(self, data):\n",
    "        \"\"\"Run ``data`` through the convolution and return the result.\"\"\"\n",
    "        output = self.conv(data)\n",
    "        return output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## {class}`~tvm.contrib.msc.core.ir.graph.MSCGraph` 与 PyTorch 模型互转"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "from torch.nn import Module\n",
    "\n",
    "import tvm.testing\n",
    "from tvm.contrib.msc.framework.torch.frontend import translate\n",
    "from tvm.contrib.msc.framework.torch import codegen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed both RNGs so the randomly initialised weights and the random inputs drawn in\n",
    "# later cells are the same on every Restart & Run All.\n",
    "SEED = 0\n",
    "torch.manual_seed(SEED)\n",
    "np.random.seed(SEED)\n",
    "\n",
    "shape = (1, 3, 224, 224)\n",
    "input_info = [(shape, \"float32\")]  # one (shape, dtype) pair per model input\n",
    "torch_model = Conv2D1()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`torch` 模型转换为 {class}`~tvm.contrib.msc.core.ir.graph.MSCGraph`："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Translate the PyTorch module into an MSCGraph and its weights, with via_relax disabled.\n",
    "graph, weights = translate.from_torch(\n",
    "    torch_model,\n",
    "    input_info,\n",
    "    via_relax=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "{class}`~tvm.contrib.msc.core.ir.graph.MSCGraph` 再转换回 `torch` 模型："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/media/pc/data/lxw/ai/tvm/python/tvm/contrib/msc/framework/torch/codegen/codegen.py:74: FutureWarning: You are using `torch.load` with `weights_only=False` (the current default value), which uses the default pickle module implicitly. It is possible to construct malicious pickle data which will execute arbitrary code during unpickling (See https://github.com/pytorch/pytorch/blob/main/SECURITY.md#untrusted-models for more details). In a future release, the default value for `weights_only` will be flipped to `True`. This limits the functions that could be executed during unpickling. Arbitrary objects will no longer be allowed to be loaded via this mode unless they are explicitly allowlisted by the user via `torch.serialization.add_safe_globals`. We recommend you start setting `weights_only=True` for any use case where you don't have full control of the loaded file. Please open an issue on GitHub for any issues related to this experimental feature.\n",
      "  state_dict = torch.load(folder.relpath(graph.name + \".pth\"))\n"
     ]
    }
   ],
   "source": [
    "# Generate a PyTorch module back from the MSCGraph and its weights.\n",
    "model = codegen.to_torch(\n",
    "    graph,\n",
    "    weights,\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "验证一致性："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw one random tensor per (shape, dtype) entry of input_info.\n",
    "torch_datas = [\n",
    "    torch.from_numpy(np.random.rand(*in_shape).astype(in_dtype))\n",
    "    for in_shape, in_dtype in input_info\n",
    "]\n",
    "\n",
    "# Run the original module, then the regenerated one, without tracking gradients.\n",
    "with torch.no_grad():\n",
    "    golden = torch_model(*torch_datas)\n",
    "    # A graph that declares no inputs is called without arguments.\n",
    "    result = model(*torch_datas) if graph.get_inputs() else model()\n",
    "\n",
    "# Wrap single outputs in a list so both sides can be compared pairwise.\n",
    "golden = golden if isinstance(golden, (list, tuple)) else [golden]\n",
    "result = result if isinstance(result, (list, tuple)) else [result]\n",
    "assert len(golden) == len(result), f\"golden {len(golden)} mismatch with result {len(result)}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Tensor outputs are compared numerically; any other output must be equal.\n",
    "for gol_r, new_r in zip(golden, result):\n",
    "    if not isinstance(gol_r, torch.Tensor):\n",
    "        assert gol_r == new_r\n",
    "        continue\n",
    "    tvm.testing.assert_allclose(\n",
    "        gol_r.detach().numpy(),\n",
    "        new_r.detach().numpy(),\n",
    "        atol=1e-5,\n",
    "        rtol=1e-5,\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 转换为 relay"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _valid_target(target):\n",
    "    \"\"\"Normalise a build target.\n",
    "\n",
    "    Falsy values are returned unchanged. The string 'ignore', and the string 'cuda'\n",
    "    when ``tvm.cuda().exist`` is false, give ``None``. Any other string is wrapped in\n",
    "    ``tvm.target.Target``; non-string targets are returned as they are.\n",
    "    \"\"\"\n",
    "    if not target:\n",
    "        return target\n",
    "    if target == \"ignore\" or (target == \"cuda\" and not tvm.cuda().exist):\n",
    "        return None\n",
    "    return tvm.target.Target(target) if isinstance(target, str) else target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _run_relax(relax_mod, target, datas):\n",
    "    \"\"\"Legalize, build and run a relax module; return its outputs as numpy arrays.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    relax_mod : module passed through LegalizeOps and then tvm.relax.build.\n",
    "    target : build target handed to tvm.relax.build.\n",
    "    datas : positional arguments for the module's ``main`` function.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    list of numpy arrays; a single NDArray result is wrapped in a one-element list.\n",
    "    \"\"\"\n",
    "    relax_mod = tvm.relax.transform.LegalizeOps()(relax_mod)\n",
    "    with tvm.transform.PassContext(opt_level=3):\n",
    "        relax_exec = tvm.relax.build(relax_mod, target)\n",
    "        # NOTE(review): the VM is always created on tvm.cpu(), whatever `target` is.\n",
    "        runnable = tvm.relax.VirtualMachine(relax_exec, tvm.cpu())\n",
    "    res = runnable[\"main\"](*datas)\n",
    "    # NDArray.numpy() replaces the deprecated asnumpy() alias.\n",
    "    if isinstance(res, tvm.runtime.NDArray):\n",
    "        return [res.numpy()]\n",
    "    return [e.numpy() for e in res]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tvm.relax.frontend.torch import from_fx\n",
    "from tvm.relay.frontend import from_pytorch\n",
    "from torch import fx\n",
    "from tvm.contrib.msc.core.frontend import translate\n",
    "from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings read by the conversion cell below; all of them are left unset here.\n",
    "opt_config = None\n",
    "codegen_config = None\n",
    "build_target = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Reference module: trace the model with torch.fx and import it straight into relax.\n",
    "graph_model = fx.symbolic_trace(torch_model)\n",
    "with torch.no_grad():\n",
    "    expected = from_fx(graph_model, input_info)\n",
    "expected = tvm.relax.transform.CanonicalizeBindings()(expected)\n",
    "\n",
    "# Candidate module: TorchScript trace, relay import, MSCGraph, then relax codegen.\n",
    "datas = [np.random.rand(*i[0]).astype(i[1]) for i in input_info]\n",
    "torch_datas = [torch.from_numpy(i) for i in datas]\n",
    "with torch.no_grad():\n",
    "    scripted_model = torch.jit.trace(torch_model, tuple(torch_datas)).eval()  # type: ignore\n",
    "shape_list = [(\"input\" + str(idx), i) for idx, i in enumerate(input_info)]\n",
    "relay_mod, params = from_pytorch(scripted_model, shape_list)\n",
    "graph, weights = translate.from_relay(relay_mod, params, opt_config=opt_config)\n",
    "# Work on a copy so the codegen_config from the settings cell is never mutated here.\n",
    "relax_codegen_config = dict(codegen_config or {})\n",
    "relax_codegen_config.update({\"explicit_name\": False, \"from_relay\": True})\n",
    "mod = tvm_codegen.to_relax(graph, weights, relax_codegen_config)\n",
    "\n",
    "if not build_target:\n",
    "    # No build target requested: the two modules must match structurally.\n",
    "    tvm.ir.assert_structural_equal(mod, expected)\n",
    "else:\n",
    "    # Keep build_target untouched so re-running the cell starts from the same setting.\n",
    "    valid_target = _valid_target(build_target)\n",
    "    if not valid_target:\n",
    "        # exit() would shut the notebook kernel down; skip the numeric check instead.\n",
    "        print(f\"build target {build_target!r} is unavailable, skipping the numeric comparison\")\n",
    "    else:\n",
    "        tvm_datas = [tvm.nd.array(i) for i in datas]\n",
    "        expected_res = _run_relax(expected, valid_target, tvm_datas)\n",
    "        # A graph that declares no inputs is run without arguments.\n",
    "        candidate_datas = tvm_datas if graph.get_inputs() else []\n",
    "        res = _run_relax(mod, valid_target, candidate_datas)\n",
    "        for exp_r, new_r in zip(expected_res, res):\n",
    "            tvm.testing.assert_allclose(exp_r, new_r, atol=1e-5, rtol=1e-5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 转换为 `relax`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tvm.testing\n",
    "from tvm.relax.frontend.torch import from_fx\n",
    "from tvm.contrib.msc.core.frontend import translate\n",
    "from tvm.contrib.msc.framework.tvm import codegen as tvm_codegen\n",
    "\n",
    "\n",
    "def _verify_model(torch_model, input_info, opt_config=None):\n",
    "    \"\"\"Check that a relax module survives a round trip through translate.from_relax.\n",
    "\n",
    "    The model is traced with torch.fx and imported with from_fx. That module is\n",
    "    translated by translate.from_relax and regenerated by tvm_codegen.to_relax. Both\n",
    "    modules are built for llvm, run on tvm.cpu() with the same random inputs, and\n",
    "    their outputs are compared with tvm.testing.assert_allclose.\n",
    "    \"\"\"\n",
    "    traced = fx.symbolic_trace(torch_model)\n",
    "    with torch.no_grad():\n",
    "        orig_mod = from_fx(traced, input_info)\n",
    "\n",
    "    target, dev = \"llvm\", tvm.cpu()\n",
    "    args = [\n",
    "        tvm.nd.array(np.random.random(size=in_shape).astype(in_dtype))\n",
    "        for in_shape, in_dtype in input_info\n",
    "    ]\n",
    "\n",
    "    def _to_numpy(obj):\n",
    "        \"\"\"Recursively convert TVM runtime values to numpy or plain Python values.\"\"\"\n",
    "        if isinstance(obj, tvm.runtime.NDArray):\n",
    "            return obj.numpy()\n",
    "        if isinstance(obj, tvm.runtime.ShapeTuple):\n",
    "            return np.array(obj, dtype=\"int64\")\n",
    "        if isinstance(obj, (list, tvm.ir.container.Array)):\n",
    "            return [_to_numpy(item) for item in obj]\n",
    "        if isinstance(obj, tuple):\n",
    "            return tuple(_to_numpy(item) for item in obj)\n",
    "        return obj\n",
    "\n",
    "    def _build_and_run(relax_mod):\n",
    "        \"\"\"Legalize and build ``relax_mod``, then call its main function on ``args``.\"\"\"\n",
    "        legalized = tvm.relax.transform.LegalizeOps()(relax_mod)\n",
    "        relax_exec = tvm.relax.build(legalized, target)\n",
    "        vm_runner = tvm.relax.VirtualMachine(relax_exec, dev)\n",
    "        return _to_numpy(vm_runner[\"main\"](*args))\n",
    "\n",
    "    graph_and_weights = translate.from_relax(orig_mod, opt_config=opt_config)\n",
    "    rt_mod = tvm_codegen.to_relax(\n",
    "        *graph_and_weights,\n",
    "        codegen_config={\"explicit_name\": False},\n",
    "    )\n",
    "\n",
    "    orig_output = _build_and_run(orig_mod)\n",
    "    rt_output = _build_and_run(rt_mod)\n",
    "    tvm.testing.assert_allclose(orig_output, rt_output)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Round-trip check for the example model, with no opt_config supplied.\n",
    "_verify_model(\n",
    "    torch_model,\n",
    "    input_info,\n",
    "    opt_config=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "xin",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
